import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet3d.models import BACKBONES

@BACKBONES.register_module()
class UNet(nn.Module):
    """Four-level U-Net encoder/decoder registered as an mmdet3d backbone.

    Args:
        n_channels: number of channels of the input tensor.
        n_classes: number of channels produced by the final 1x1 projection.
        bilinear: if True, decoder stages upsample with ``nn.Upsample``
            (bilinear) and the bottleneck/decoder widths are halved; otherwise
            transposed convolutions are used.
    """

    def __init__(self, n_channels, n_classes, bilinear=False):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Bilinear upsampling does not reduce channels, so the bottleneck and
        # decoder widths are halved to keep the concatenations consistent.
        shrink = 2 if bilinear else 1

        # Encoder: full-resolution stem followed by four 2x downsampling stages.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // shrink)

        # Decoder: each stage upsamples 2x and fuses the matching encoder skip.
        self.up1 = Up(1024, 512 // shrink, bilinear)
        self.up2 = Up(512, 256 // shrink, bilinear)
        self.up3 = Up(256, 128 // shrink, bilinear)
        self.up4 = Up(128, 64, bilinear)

        # Final 1x1 projection to n_classes channels.
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Run the encoder/decoder and return the projected feature map.

        Shape notes assume an NCHW input of (B, n_channels, H, W) with
        bilinear=False. Each max pool floor-divides H and W by 2, and each Up
        stage pads/crops back to the size of its skip connection, so the
        output matches the input's spatial size.
        """
        # NOTE(review): earlier comments here mentioned n_channels = 64 * 16,
        # n_classes = 256 and a (B, 64, 200, 200) input; confirm against the
        # model config.
        skip1 = self.inc(x)              # (B, 64, H, W)
        skip2 = self.down1(skip1)        # (B, 128, H/2, W/2)
        skip3 = self.down2(skip2)        # (B, 256, H/4, W/4)
        skip4 = self.down3(skip3)        # (B, 512, H/8, W/8)
        bottleneck = self.down4(skip4)   # (B, 1024, H/16, W/16)

        y = self.up1(bottleneck, skip4)  # (B, 512, H/8, W/8)
        y = self.up2(y, skip3)           # (B, 256, H/4, W/4)
        y = self.up3(y, skip2)           # (B, 128, H/2, W/2)
        y = self.up4(y, skip1)           # (B, 64, H, W)
        return self.outc(y)              # (B, n_classes, H, W)



class DoubleConv(nn.Module):
    """Two stacked (3x3 conv => BatchNorm => ReLU) stages.

    Padding of 1 keeps the spatial size unchanged.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the second convolution.
        mid_channels: channels produced by the first convolution; any falsy
            value (None, 0) falls back to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels
        layers = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            # Conv bias is disabled; presumably because the following
            # BatchNorm2d supplies its own learnable shift.
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv/BN/ReLU stages to ``x``."""
        out = self.double_conv(x)
        return out


class Down(nn.Module):
    """Downscale 2x with max pooling, then refine with a DoubleConv.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [nn.MaxPool2d(2), DoubleConv(in_channels, out_channels)]
        self.maxpool_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Pool ``x`` by a factor of 2, then apply the DoubleConv."""
        pooled_and_convolved = self.maxpool_conv(x)
        return pooled_and_convolved


class Up(nn.Module):
    """Upsample the deeper map 2x, concatenate it with a skip map, then DoubleConv.

    Args:
        in_channels: in transposed mode, the channels of ``x1`` (halved by the
            transposed conv); in both modes, the channel count the concatenated
            tensor must have when it enters the DoubleConv.
        out_channels: channels produced by the DoubleConv.
        bilinear: use parameter-free bilinear upsampling (channels unchanged)
            instead of a learnable transposed convolution.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2
        if not bilinear:
            # Learnable 2x upsampling that also halves the channel count.
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            # Bilinear upsampling keeps channels, so the DoubleConv narrows
            # through in_channels // 2 intermediate channels instead.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it to ``x2``'s spatial size, fuse and convolve.

        Args:
            x1: deeper (lower-resolution) feature map.
            x2: skip-connection feature map; its spatial size wins.
        """
        upsampled = self.up(x1)
        # Assumes NCHW: dims 2 and 3 are height and width.
        dy = x2.size(2) - upsampled.size(2)
        dx = x2.size(3) - upsampled.size(3)
        # F.pad lists the last dim first: (left, right, top, bottom). Odd
        # differences put the extra pixel on the right/bottom; negative
        # differences crop.
        upsampled = F.pad(upsampled, [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        return self.conv(torch.cat([x2, upsampled], dim=1))


class Custom_Up(nn.Module):
    """Transposed-conv 2x upsampling, fused with a skip map, then DoubleConv.

    Unlike ``Up``, the channel counts of the two inputs are given separately,
    so the skip connection's width need not equal the upsampled width.

    Args:
        x1_channel: channels of the deeper map ``x1``; the transposed conv
            halves this count.
        x2_channel: channels of the skip-connection map ``x2``.
        out_channels: channels produced by the final DoubleConv.
        bilinear: must be False; bilinear upsampling is not implemented here.

    Raises:
        ValueError: if ``bilinear`` is truthy.
    """

    def __init__(self, x1_channel, x2_channel, out_channels, bilinear=False):
        super().__init__()
        if bilinear:
            # Previously the message was just "bilinear=False", which read as
            # if False were the rejected value.
            raise ValueError(
                "Custom_Up only supports bilinear=False (transposed-convolution "
                f"upsampling); got bilinear={bilinear!r}")
        self.up = nn.ConvTranspose2d(x1_channel, x1_channel // 2, kernel_size=2, stride=2)
        # Concatenation of the upsampled x1 (x1_channel // 2) and x2 (x2_channel).
        self.conv = DoubleConv(x1_channel // 2 + x2_channel, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it to ``x2``'s spatial size, fuse and convolve.

        Args:
            x1: deeper (lower-resolution) feature map with ``x1_channel`` channels.
            x2: skip-connection feature map with ``x2_channel`` channels.
        """
        x1 = self.up(x1)
        # Assumes NCHW: dims 2 and 3 are height and width.
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        # F.pad order for the last two dims: (left, right, top, bottom);
        # negative differences crop.
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    """1x1 convolution projecting features to ``out_channels`` channels.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Apply the 1x1 projection; spatial size is unchanged."""
        projected = self.conv(x)
        return projected

